{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0052f074",
   "metadata": {},
   "source": [
    "# Privacy-Preserving Machine Learning on Titanic\n",
    "\n",
    "This notebook introduces a Privacy-Preserving Machine Learning (PPML) solution to the [Kaggle Titanic competition](https://www.kaggle.com/c/titanic/) using the [Concrete ML](https://docs.zama.ai/concrete-ml) open-source framework. Its main ambition is to show that [Fully Homomorphic Encryption](https://en.wikipedia.org/wiki/Homomorphic_encryption) (FHE) can be used for protecting data when using a Machine Learning model to predict outcomes without degrading its performance. In this example, a [XGBoost](https://xgboost.readthedocs.io/en/release_3.0.0/) classifier model will be considered as it achieves near state-of-the-art accuracy.\n",
    "\n",
    "It also took some ideas from several upvoted public notebooks, including [this one](https://www.kaggle.com/code/startupsci/titanic-data-science-solutions/notebook) from Manav Sehgal and [this one](https://www.kaggle.com/code/ldfreeman3/a-data-science-framework-to-achieve-99-accuracy#Step-3:-Prepare-Data-for-Consumption) from LD Freeman."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3e2ca58",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7c415ea3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.datasets import fetch_openml\n",
    "from sklearn.model_selection import GridSearchCV, ShuffleSplit, train_test_split\n",
    "from xgboost import XGBClassifier\n",
    "\n",
    "from concrete.ml.sklearn import XGBClassifier as ConcreteXGBClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ba063203",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Index(['pclass', 'survived', 'name', 'sex', 'age', 'sibsp', 'parch', 'ticket',\n",
      "       'fare', 'cabin', 'embarked', 'boat', 'body', 'home.dest'],\n",
      "      dtype='object')\n"
     ]
    }
   ],
   "source": [
    "# pylint: disable-next=no-member\n",
    "all_data = fetch_openml(data_id=40945, as_frame=True, cache=True).frame\n",
    "\n",
    "all_data = all_data.convert_dtypes()\n",
    "all_data[\"survived\"] = all_data[\"survived\"].astype(int)\n",
    "print(all_data.columns)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b79c0e8",
   "metadata": {},
   "source": [
    "## Preprocessing the Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5121bf3e",
   "metadata": {},
   "source": [
    "Be sure to launch the `download_data.sh` script in order to have local versions of the data-set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f5f3802c",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data, test_data = train_test_split(all_data, test_size=len(all_data.index) - 891)\n",
    "datasets = [train_data, test_data]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7ec9783",
   "metadata": {},
   "source": [
    "Let's take a closer look at the train data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ae653f6f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>pclass</th>\n",
       "      <th>survived</th>\n",
       "      <th>name</th>\n",
       "      <th>sex</th>\n",
       "      <th>age</th>\n",
       "      <th>sibsp</th>\n",
       "      <th>parch</th>\n",
       "      <th>ticket</th>\n",
       "      <th>fare</th>\n",
       "      <th>cabin</th>\n",
       "      <th>embarked</th>\n",
       "      <th>boat</th>\n",
       "      <th>body</th>\n",
       "      <th>home.dest</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>995</th>\n",
       "      <td>3</td>\n",
       "      <td>0.0</td>\n",
       "      <td>Markoff, Mr. Marin</td>\n",
       "      <td>male</td>\n",
       "      <td>35.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>349213</td>\n",
       "      <td>7.8958</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>C</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1147</th>\n",
       "      <td>3</td>\n",
       "      <td>0.0</td>\n",
       "      <td>Riihivouri, Miss. Susanna Juhantytar 'Sanni'</td>\n",
       "      <td>female</td>\n",
       "      <td>22.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>3101295</td>\n",
       "      <td>39.6875</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>S</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>485</th>\n",
       "      <td>2</td>\n",
       "      <td>0.0</td>\n",
       "      <td>Levy, Mr. Rene Jacques</td>\n",
       "      <td>male</td>\n",
       "      <td>36.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>SC/Paris 2163</td>\n",
       "      <td>12.875</td>\n",
       "      <td>D</td>\n",
       "      <td>C</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>Montreal, PQ</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1265</th>\n",
       "      <td>3</td>\n",
       "      <td>0.0</td>\n",
       "      <td>Van Impe, Miss. Catharina</td>\n",
       "      <td>female</td>\n",
       "      <td>10.0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>345773</td>\n",
       "      <td>24.15</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>S</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>632</th>\n",
       "      <td>3</td>\n",
       "      <td>0.0</td>\n",
       "      <td>Andersson, Mrs. Anders Johan (Alfrida Konstant...</td>\n",
       "      <td>female</td>\n",
       "      <td>39.0</td>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>347082</td>\n",
       "      <td>31.275</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>S</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>Sweden Winnipeg, MN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      pclass  survived                                               name   \n",
       "995        3       0.0                                 Markoff, Mr. Marin  \\\n",
       "1147       3       0.0       Riihivouri, Miss. Susanna Juhantytar 'Sanni'   \n",
       "485        2       0.0                             Levy, Mr. Rene Jacques   \n",
       "1265       3       0.0                          Van Impe, Miss. Catharina   \n",
       "632        3       0.0  Andersson, Mrs. Anders Johan (Alfrida Konstant...   \n",
       "\n",
       "         sex   age  sibsp  parch         ticket     fare cabin embarked  boat   \n",
       "995     male  35.0      0      0         349213   7.8958  <NA>        C  <NA>  \\\n",
       "1147  female  22.0      0      0        3101295  39.6875  <NA>        S  <NA>   \n",
       "485     male  36.0      0      0  SC/Paris 2163   12.875     D        C  <NA>   \n",
       "1265  female  10.0      0      2         345773    24.15  <NA>        S  <NA>   \n",
       "632   female  39.0      1      5         347082   31.275  <NA>        S  <NA>   \n",
       "\n",
       "      body            home.dest  \n",
       "995   <NA>                 <NA>  \n",
       "1147  <NA>                 <NA>  \n",
       "485   <NA>         Montreal, PQ  \n",
       "1265  <NA>                 <NA>  \n",
       "632   <NA>  Sweden Winnipeg, MN  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c26d9ae",
   "metadata": {},
   "source": [
    "We can observe:\n",
    "- the target variable: Survived\n",
    "- some numerical variables: Pclass, SbSp, Parch, Fare\n",
    "- some categorical (non-numerical) variables: Name, Sex, Ticket, Cabin, Embarked"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7259b1ab",
   "metadata": {},
   "source": [
    "### Missing Values\n",
    "\n",
    "Then, we can notice that some values are missing for the Cabin variable. We must therefore investigate a bit more about this by printing the total amounts of missing values for each variables:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "737d6fdb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train data ---\n",
      "pclass         0\n",
      "survived       0\n",
      "name           0\n",
      "sex            0\n",
      "age          173\n",
      "sibsp          0\n",
      "parch          0\n",
      "ticket         0\n",
      "fare           0\n",
      "cabin        691\n",
      "embarked       1\n",
      "boat         556\n",
      "body         806\n",
      "home.dest    378\n",
      "dtype: int64 \n",
      "\n",
      "--- Test data ---\n",
      "pclass         0\n",
      "survived       0\n",
      "name           0\n",
      "sex            0\n",
      "age           90\n",
      "sibsp          0\n",
      "parch          0\n",
      "ticket         0\n",
      "fare           1\n",
      "cabin        323\n",
      "embarked       1\n",
      "boat         267\n",
      "body         382\n",
      "home.dest    186\n",
      "dtype: int64\n"
     ]
    }
   ],
   "source": [
    "print(\"-\" * 3, \"Train data\", \"-\" * 3)\n",
    "print(train_data.isnull().sum(), \"\\n\")\n",
    "print(\"-\" * 3, \"Test data\", \"-\" * 3)\n",
    "print(test_data.isnull().sum())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f945a8f5",
   "metadata": {},
   "source": [
    "Only four variables are incomplete: Cabin, Age, Embarked and Fare. However, the Cabin one seems to be missing quite more data than the others:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0266c622",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Percentage of missing values in age: 20.1%\n",
      "Percentage of missing values in fare: 0.1%\n",
      "Percentage of missing values in cabin: 77.5%\n",
      "Percentage of missing values in embarked: 0.2%\n",
      "Percentage of missing values in boat: 62.9%\n",
      "Percentage of missing values in body: 90.8%\n",
      "Percentage of missing values in home.dest: 43.1%\n"
     ]
    }
   ],
   "source": [
    "for incomp_var in train_data.columns:\n",
    "    missing_val = pd.concat(datasets)[incomp_var].isnull().sum()\n",
    "    if missing_val > 0 and incomp_var != \"survived\":\n",
    "        total_val = pd.concat(datasets).shape[0]\n",
    "        print(f\"Percentage of missing values in {incomp_var}: {missing_val/total_val*100:.1f}%\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e51794a8",
   "metadata": {},
   "source": [
    "Since the Cabin variable misses more than 2/3 of its values, it might not be relevant to keep it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "35d9ceb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "drop_column = [\"cabin\", \"boat\", \"body\", \"home.dest\", \"ticket\"]\n",
    "\n",
    "for dataset in datasets:\n",
    "    # Remove irrelevant variables\n",
    "    dataset.drop(drop_column, axis=1, inplace=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39c3e038",
   "metadata": {},
   "source": [
    "For the other ones, we can replace the missing values using:\n",
    "- the median value for Age and Fare\n",
    "- the most common value for Embarked"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "565ec656",
   "metadata": {},
   "outputs": [],
   "source": [
    "for dataset in datasets:\n",
    "    # Complete missing Age values with median\n",
    "    dataset.age.fillna(dataset.age.median(), inplace=True)\n",
    "\n",
    "    # Complete missing Embarked values with the most common one\n",
    "    dataset.embarked.fillna(dataset.embarked.mode()[0], inplace=True)\n",
    "\n",
    "    # Complete missing Fare values with median\n",
    "    dataset.fare.fillna(dataset.fare.median(), inplace=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61c84a2b",
   "metadata": {},
   "source": [
    "### Feature Engineering\n",
    "\n",
    "We can manually extract and create new features in order to help the model interpret some behaviors and correctly predict an outcome. Those new features are:\n",
    "- FamilySize: The size of the family the individual was traveling with, with 1 being someone that traveled alone. \n",
    "- IsAlone: A boolean variable stating if the individual was traveling alone (1) or not (0). This might help the model to emphasize on this idea of traveling with relatives or not.\n",
    "- Title: The individual's title (Mr, Mrs, ...), often indicating a certain social status.\n",
    "- Farebin and AgeBin: Binned version of the Fare and Age variables. It groups values together, generally reducing the impact of minor observation errors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ff3e032d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_bin_labels(bin_name, number_of_bins):\n",
    "    labels = []\n",
    "    for i in range(number_of_bins):\n",
    "        labels.append(bin_name + f\"_{i}\")\n",
    "    return labels\n",
    "\n",
    "\n",
    "for dataset in datasets:\n",
    "    # Emphasize on relatives\n",
    "    dataset[\"FamilySize\"] = dataset.sibsp + dataset.parch + 1\n",
    "\n",
    "    dataset[\"IsAlone\"] = 1\n",
    "    dataset.IsAlone[dataset.FamilySize > 1] = 0\n",
    "\n",
    "    # Consider an individual's social status\n",
    "    dataset[\"Title\"] = dataset.name.str.extract(r\" ([A-Za-z]+)\\.\", expand=False)\n",
    "\n",
    "    # Group fares and ages in \"bins\" or \"buckets\"\n",
    "    dataset[\"FareBin\"] = pd.qcut(dataset.fare, 4, labels=get_bin_labels(\"FareBin\", 4))\n",
    "    dataset[\"AgeBin\"] = pd.cut(dataset.age.astype(int), 5, labels=get_bin_labels(\"AgeBin\", 5))\n",
    "\n",
    "    # Remove now-irrelevant variables\n",
    "    drop_column = [\"name\", \"sibsp\", \"parch\", \"fare\", \"age\"]\n",
    "    dataset.drop(drop_column, axis=1, inplace=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3b5576e",
   "metadata": {},
   "source": [
    "Let's have a look on the titles' distribution:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "638f3f6c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Title\n",
      "Mr          757\n",
      "Miss        260\n",
      "Mrs         197\n",
      "Master       61\n",
      "Dr            8\n",
      "Rev           8\n",
      "Col           4\n",
      "Ms            2\n",
      "Major         2\n",
      "Mlle          2\n",
      "Lady          1\n",
      "Mme           1\n",
      "Countess      1\n",
      "Dona          1\n",
      "Sir           1\n",
      "Jonkheer      1\n",
      "Capt          1\n",
      "Don           1\n",
      "Name: count, dtype: Int64\n"
     ]
    }
   ],
   "source": [
    "data = pd.concat(datasets)\n",
    "titles = data.Title.value_counts()\n",
    "print(titles)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f484bfd4",
   "metadata": {},
   "source": [
    "We can clearly observe that only a few titles represent most of the individuals. In order to prevent the model from becoming overly specific, we decide to group all the \"uncommon\" titles together:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "498273da",
   "metadata": {},
   "outputs": [],
   "source": [
    "uncommon_titles = titles[titles < 10].index\n",
    "\n",
    "for dataset in datasets:\n",
    "    dataset.Title.replace(uncommon_titles, \"Rare\", inplace=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18937e31",
   "metadata": {},
   "source": [
    "### Dummification\n",
    "\n",
    "Finally, we can \"dummify\" the remaining categorical variables. Dummification is a common technique of transforming categorical (non-numerical) data into numerical data without having to map values or consider any order between each of them. The idea is to take all the different values found in a variable and create a new associated binary variable. \n",
    "\n",
    "For example, the \"Embarked\" variable has three categorical values: \"S\", \"C\" and \"Q\". Dummifying the data will create three new variables, \"Embarked_S\", \"Embarked_C\" and \"Embarked_Q\", and set the value of \"Embarked_S\" (resp. \"Embarked_C\" and \"Embarked_Q\") to 1 for each data point initially labeled with \"S\" (resp. \"C\" and \"Q\") in the \"Embarked\" variable, else 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "1b0539c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "categorical_features = train_data.select_dtypes(exclude=np.number).columns\n",
    "x_train = pd.get_dummies(train_data, prefix=categorical_features)\n",
    "x_test = pd.get_dummies(test_data, prefix=categorical_features)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cdf3e7f9",
   "metadata": {},
   "source": [
    "We then split the target variable from the others."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "457b1c5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "target = \"survived\"\n",
    "x_train = x_train.drop(columns=[target])\n",
    "y_train = train_data[target]\n",
    "\n",
    "x_test = x_test.drop(columns=[target])\n",
    "y_test = test_data[target]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f6b246c",
   "metadata": {},
   "source": [
    "## Training \n",
    "### Training with XGBoost\n",
    "\n",
    "Let's first train a classifier model using XGBoost. Since several parameters have to be fixed beforehand, we use scikit-learn's GridSearchCV method to perform cross validation in order to maximize our chance to find the best ones. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "2cc269e1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best hyper-parameters found in 7.98s : {'learning_rate': 1, 'max_depth': 2, 'n_estimators': 4}\n"
     ]
    }
   ],
   "source": [
    "# Instantiate the Cross-Validation generator\n",
    "cv = ShuffleSplit(n_splits=5, test_size=0.3, random_state=0)\n",
    "\n",
    "# Set the parameters to tune.\n",
    "# Those ranges are voluntarily small in order to keep the FHE execution time per inference\n",
    "# relatively low. In fact, we found out that, in this particular Titanic example, models with\n",
    "# larger numbers of estimators or maximum depth don't score a much better accuracy.\n",
    "param_grid = {\n",
    "    \"max_depth\": list(range(1, 5)),\n",
    "    \"n_estimators\": list(range(1, 5)),\n",
    "    \"learning_rate\": [0.01, 0.1, 1],\n",
    "}\n",
    "\n",
    "# Instantiate and fit the model through grid-search cross-validation\n",
    "time_begin = time.time()\n",
    "model = GridSearchCV(XGBClassifier(), param_grid, cv=cv, scoring=\"roc_auc\")\n",
    "model.fit(x_train, y_train)\n",
    "cv_xgb_duration = time.time() - time_begin\n",
    "\n",
    "print(f\"Best hyper-parameters found in {cv_xgb_duration:.2f}s :\", model.best_params_)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cfde39b3",
   "metadata": {},
   "source": [
    "### Training with Concrete ML\n",
    "\n",
    "Now, let's do the same with Concrete ML's XGBClassifier method. \n",
    "\n",
    "In order to do so, we need to specify the number of bits over which inputs, outputs and weights will be quantized. This value can influence the precision of the model as well as its inference running time, and therefore can lead the grid-search cross-validation to find a different set of parameters. In our case, setting this value to 2 bits outputs an excellent accuracy score while running faster. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "9cc1cc89",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best hyper-parameters found in 40.99s : {'learning_rate': 0.1, 'max_depth': 4, 'n_bits': 2, 'n_estimators': 4}\n"
     ]
    }
   ],
   "source": [
    "# The Concrete ML model needs an additional parameter used for quantization\n",
    "param_grid[\"n_bits\"] = [2]\n",
    "\n",
    "# Instantiate and fit the model through grid-search cross-validation\n",
    "time_begin = time.time()\n",
    "concrete_model = GridSearchCV(ConcreteXGBClassifier(), param_grid, cv=cv, scoring=\"roc_auc\")\n",
    "concrete_model.fit(x_train, y_train)\n",
    "cv_concrete_duration = time.time() - time_begin\n",
    "\n",
    "print(f\"Best hyper-parameters found in {cv_concrete_duration:.2f}s :\", concrete_model.best_params_)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a40aba09",
   "metadata": {},
   "source": [
    "## Predicting the Outcomes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2c109f4",
   "metadata": {},
   "source": [
    "Computing the predictions in FHE on the complete test set of 418 data points using the above hyper-parameters can take up to 1 minute, using a [c5.4xlarge AWS instance](https://aws.amazon.com/ec2/instance-types/c5/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "749c026d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Key generation time: 0.94s\n",
      "Total execution time for 418 inferences: 65.49s\n",
      "Execution time per inference in FHE: 0.16s\n"
     ]
    }
   ],
   "source": [
    "# Compute the predictions in clear using XGBoost\n",
    "clear_predictions = model.predict(x_test)\n",
    "\n",
    "# Compute the predictions in clear using Concrete ML\n",
    "clear_quantized_predictions = concrete_model.predict(x_test)\n",
    "\n",
    "# Compile the Concrete ML model on a subset\n",
    "fhe_circuit = concrete_model.best_estimator_.compile(x_train.head(100))\n",
    "\n",
    "# Generate the keys\n",
    "# This step is not absolutely necessary, as keygen() is called, when necessary,\n",
    "# within the predict method.\n",
    "# However, it is useful to run it beforehand in order to be able to\n",
    "# measure the prediction executing time separately from the key generation one\n",
    "time_begin = time.time()\n",
    "fhe_circuit.keygen()\n",
    "key_generation_duration = time.time() - time_begin\n",
    "\n",
    "# Compute the predictions in FHE using Concrete ML\n",
    "time_begin = time.time()\n",
    "fhe_predictions = concrete_model.best_estimator_.predict(x_test, fhe=\"execute\")\n",
    "prediction_duration = time.time() - time_begin\n",
    "\n",
    "print(f\"Key generation time: {key_generation_duration:.2f}s\")\n",
    "print(f\"Total execution time for {len(clear_predictions)} inferences: {prediction_duration:.2f}s\")\n",
    "print(f\"Execution time per inference in FHE: {prediction_duration / len(clear_predictions):.2f}s\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a88424d",
   "metadata": {},
   "source": [
    "FHE computations are expected to be exact. This means that the model executed in FHE results in the same predictions as the Concrete ML one, which is executed in clear and only considers quantization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "82bb69dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prediction similarity between both Concrete ML models (quantized clear and FHE): 100.00%\n",
      "Accuracy of prediction in FHE on the test set 78.23%\n"
     ]
    }
   ],
   "source": [
    "number_of_equal_preds = np.sum(fhe_predictions == clear_quantized_predictions)\n",
    "pred_similarity = number_of_equal_preds / len(clear_predictions) * 100\n",
    "print(\n",
    "    \"Prediction similarity between both Concrete ML models (quantized clear and FHE): \"\n",
    "    f\"{pred_similarity:.2f}%\",\n",
    ")\n",
    "\n",
    "accuracy_fhe = np.mean(fhe_predictions == y_test) * 100\n",
    "print(\n",
    "    \"Accuracy of prediction in FHE on the test set \" f\"{accuracy_fhe:.2f}%\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95331f92",
   "metadata": {},
   "source": [
    "However, as seen previously, the grid-search cross-validation was done separately between the XGBoost model and the Concrete ML one. For this reason, the two models do not share the same set of hyper-parameters, making their decision boundaries different.\n"
   ]
  }
 ],
 "metadata": {
  "execution": {
   "timeout": 10800
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
